{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(80000, 3) (20000, 3)\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "# Load the ml-100k u1 split, keeping only the first three columns\n",
    "# (user id, item id, rating) as plain NumPy arrays.\n",
    "train, test = [\n",
    "    pd.read_table(path, sep='\\t', header=None).iloc[:, :3].values\n",
    "    for path in ('../datasets/ml-100k/u1.base',\n",
    "                 '../datasets/ml-100k/u1.test')\n",
    "]\n",
    "\n",
    "# IDs in the data start at 1, so one extra slot lets arrays be indexed by raw ID.\n",
    "n_users = 943 + 1\n",
    "n_items = 1682 + 1\n",
    "n_samples = len(train)\n",
    "\n",
    "print(train.shape, test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型基本\n",
    "## 参数设定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 20    # number of latent factors\n",
    "\n",
    "glob_mean = np.mean(train[:, 2])    # global mean rating\n",
    "\n",
    "# Bias terms start at zero.\n",
    "bi = np.zeros(n_items)\n",
    "bu = np.zeros(n_users)\n",
    "\n",
    "# Latent factors must start from small random values. With an all-zero\n",
    "# initialisation the SGD gradients err*pu and err*qi are both zero, so the\n",
    "# factors never move and the model degenerates to a bias-only baseline.\n",
    "np.random.seed(0)    # fixed seed so the run is reproducible\n",
    "qi = np.random.normal(0, 0.1, (n_items, k))\n",
    "pu = np.random.normal(0, 0.1, (n_users, k))\n",
    "\n",
    "# Lookup dicts, used instead of building a large sparse matrix\n",
    "item_user_dict = dict()\n",
    "user_item_dict = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index every training rating in both directions:\n",
    "#   item_user_dict[item_id][user_id] -> rating\n",
    "#   user_item_dict[user_id][item_id] -> rating\n",
    "for user_id, item_id, rating in train:\n",
    "    # setdefault returns the (possibly freshly created) inner dict\n",
    "    item_user_dict.setdefault(item_id, {})[user_id] = rating\n",
    "    user_item_dict.setdefault(user_id, {})[item_id] = rating"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 1.034226696186526\n",
      "1 0.9193063490388131\n",
      "2 0.8905347580429747\n",
      "3 0.8771195885444455\n",
      "4 0.8688715384079316\n",
      "5 0.8638050550667303\n",
      "6 0.8600667901603428\n",
      "7 0.8577347026688281\n",
      "8 0.8552588375611702\n",
      "9 0.8541307628195199\n",
      "10 0.8527736487592168\n",
      "11 0.8518794148114602\n",
      "12 0.8508790143733819\n",
      "13 0.8502341996279189\n",
      "14 0.8494970644918951\n",
      "15 0.8490790063503307\n",
      "16 0.8486372883947207\n",
      "17 0.8482090808401567\n",
      "18 0.8477907383287997\n",
      "19 0.8472711192077239\n"
     ]
    }
   ],
   "source": [
    "max_iter = 20    # number of passes over the training set\n",
    "lr = 0.01    # learning rate\n",
    "alpha = 0.1    # regularisation coefficient\n",
    "\n",
    "for epoch in range(max_iter):\n",
    "    MSE = 0\n",
    "    # Visit the training samples in a fresh random order every epoch.\n",
    "    random_idxs = np.random.permutation(n_samples)\n",
    "\n",
    "    for idx in random_idxs:\n",
    "        user_id, item_id, rating = train[idx]\n",
    "        # Prediction = global mean + item bias + user bias + factor dot product\n",
    "        y_pred = glob_mean + bi[item_id] + bu[user_id] + np.dot(pu[user_id], qi[item_id])\n",
    "        err = rating - y_pred\n",
    "        MSE += err**2\n",
    "\n",
    "        bu[user_id] += lr*(err - alpha*bu[user_id])\n",
    "        bi[item_id] += lr*(err - alpha*bi[item_id])\n",
    "        # Take a real copy of the item vector: qi[item_id] is a view, so\n",
    "        # without .copy() the in-place qi update below would also change the\n",
    "        # value used for the pu update.\n",
    "        qi_old = qi[item_id].copy()\n",
    "        qi[item_id] += lr*(err*pu[user_id] - alpha*qi[item_id])\n",
    "        pu[user_id] += lr*(err*qi_old - alpha*pu[user_id])\n",
    "\n",
    "    # Mean squared error accumulated over this epoch (measured before each update)\n",
    "    MSE /= n_samples\n",
    "    print(epoch, MSE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试误差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9148970313988244\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained model on the held-out split.\n",
    "Y_pred = list()    # per-sample predictions, in the same order as `test`\n",
    "test_mse = 0\n",
    "for user_id, item_id, rating in test:\n",
    "    y_pred = glob_mean + bi[item_id] + bu[user_id] + np.dot(pu[user_id], qi[item_id])\n",
    "    Y_pred.append(y_pred)    # previously the list was created but never filled\n",
    "    test_mse += (rating - y_pred)**2\n",
    "test_mse /= len(test)\n",
    "\n",
    "print(test_mse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
